%% Clean up
% Start from an empty workspace so no stale results leak into this run.
clear variables

%% Main
% Auxiliary numerical settings kept together in one struct:
%   tol_J - numerical tolerance (not read elsewhere in this script)
%   n_m   - number of grid points for the initial-m CDF grid
aux = struct('tol_J', 10^(-16), ...
             'n_m',   10^3);

%% Specification
% The simulation clock runs in months.
n_sample    = 10^7;       % number of simulated firms
death_rate  = 0.02348;    % annual firm death rate
age_limit   = 20;         % oldest admissible firm age, in years
drift_max   = 2;          % maximum monthly employment drift
days        = 26;         % days per month used to size the time step
dt          = 1/days/10;  % step length in months (10 steps per day)

%% Paths
% Folder names are built with filesep so the script also runs on
% macOS/Linux; on Windows they equal 'parameters_22\' and 'Output\'.
% The trailing separator is kept so direct concatenation such as
% [path, 'file'] keeps working.
% NOTE(review): `path` shadows MATLAB's built-in path() in this script.
path     = ['parameters_22', filesep];
datapath = ['Output', filesep];
filename = 'Chains.xlsx';
name     = 'replication';

% fullfile joins folder and file name without doubling the separator.
% The first file must provide param_ini (used below); the second
% presumably provides param_end.
load(fullfile(path, ['param_ini_', name, '.mat']));
load(fullfile(path, ['param_end_', name, '.mat']));
param = param_ini;

param.d         = death_rate/12;  % monthly death rate (annual rate / 12)
param.drift_max = drift_max;      % cap on monthly employment drift

%% Create drift of m
% Evaluation grid param.mm for m, built from (in order):
%   - two far sentinels, -100 and m_u+100;
%   - points just below and at m_l, and just below m_h;
%   - 100 points from m_h-1e-10 to m_u-1e-5 spaced as 1 - s.^3, so they are
%     densest near m_u;
%   - m_u itself.
% NOTE(review): interp1 needs distinct grid points, which assumes
% m_l < m_h < m_u in the loaded parameters. TODO confirm.
param.mm   = [-100, param.m_l-10^(-10), param.m_l, param.m_h - 10^(-9), ...
    (param.m_h - 10^(-10)) + (param.m_u - 10^(-5)-(param.m_h - 10^(-10))).*(1 - linspace( 1, 0, 10^2).^3), param.m_u,param.m_u + 100];
% m_max is set equal to m_e; both live in the loaded parameter struct.
param.m_max = param.m_e;
% Drift of ln(m) at each point of param.mm. The employment-drift term
% comes from emp_drift, which is defined outside this script.
param.drift = (param.mu-param.sigma^2/2-(1-param.alpha).*emp_drift(param.mm,param));% drift for ln(m)

%% Create functions for the drift and diffusion of the process for ln(m)
% The SDE is d ln(m) = F(t, ln m) dt + G(t, ln m) dW.
% F linearly interpolates param.drift at m = exp(lnm). It ignores t and
% returns NaN for m above param.mm(end), because interp1 does not
% extrapolate by default.
F = @(t,lnm) (interp1(param.mm,param.drift, exp(lnm),'linear'));%
% G is the diffusion (volatility) term: the constant sigma.
G = @(t,lnm) param.sigma;

% Log-spaced grid of aux.n_m points on [m_l, m_u], and the initial CDF of m
% on that grid. G_ini is normalised to 0 at m_l and (up to rounding) 1 at m_u.
% q_fun is defined outside this script. The grid is later inverted with
% interp1 for inverse-CDF sampling, which needs G_ini strictly increasing.
% Presumably q_fun is increasing; confirm.
mm  = exp( linspace( log(param.m_l), log(param.m_u), aux.n_m));
G_ini = (q_fun(mm,param)-q_fun(param.m_l,param))./(q_fun(param.m_u,param)-q_fun(param.m_l,param));

%% Random draws and output preallocation
rng(10^8);                      % seed the global generator for the draws below
age_max = 12*age_limit;         % maximum firm age, in months
F_tmp   = rand(n_sample,1);     % uniforms for inverse-CDF sampling of initial m
F_tmp_a = rand(n_sample,1);     % uniforms for the age draw

% Age at the end of the simulation: exponential with monthly rate param.d,
% truncated at age_max and drawn by inverse CDF:
%   a = -log(1 - u*(1 - exp(-d*age_max)))/d,  u ~ U(0,1).
F_max   = 1 - exp(-age_max*param.d);
F_tmp_a = F_tmp_a.*F_max;
Age_end = -log(1 - F_tmp_a)./param.d;

% Preallocate every per-firm output the parfor loop fills. Without this,
% M_ini/M_end were created implicitly inside the loop, and a stale copy left
% by a cell-mode re-run could have the wrong size.
% N_* are columns. M_* are rows, matching the shape the loop produced
% before, so downstream consumers of M_end are unaffected.
N_ini = nan(n_sample,1);
N_end = nan(n_sample,1);
M_ini = nan(1,n_sample);
M_end = nan(1,n_sample);

% One Threefry stream shared with all workers. Each loop iteration selects
% its own substream, so results do not depend on worker scheduling.
sc = parallel.pool.Constant(RandStream('Threefry'));

%% Simulate each firm's productivity and employment path
tic
parfor (i = 1:n_sample)

    % Firm i draws from substream i of the shared stream. Remember the
    % worker's previous global stream so it can be restored afterwards.
    stream = sc.Value;
    set(stream,'Substream',i);
    prev = RandStream.setGlobalStream(stream);

    % Initial m by inverse-CDF sampling from G_ini. All firms start with
    % the same employment.
    m_ini = interp1(G_ini, mm, F_tmp(i), 'linear');
    n_ini = 20;

    N_ini(i) = n_ini;
    M_ini(i) = m_ini;

    % Number of steps of length dt covering the firm's age (both in months).
    nPeriods = floor(Age_end(i)/dt) + 1;

    % Simulate ln(m). simulate returns nPeriods+1 rows: the start state
    % followed by one state per step.
    SDE = sde(F, G, 'StartState', log(m_ini));
    [lnm, t] = SDE.simulate(nPeriods, 'DeltaTime', dt, 'nTrials', 1);

    % Log-employment drift evaluated at every simulated state.
    dlnn = (param.mu - param.sigma^2/2 - F(t, lnm))./(1 - param.alpha);

    % Forward Euler in log-employment: step k uses the drift at its starting
    % state, so only the first nPeriods drifts are integrated. The original
    % cumsum over all nPeriods+1 states added one spurious step at the end.
    N_end(i) = n_ini*exp(dt*sum(dlnn(1:end-1)));
    M_end(i) = exp(lnm(end));

    RandStream.setGlobalStream(prev);
end

toc

%% Figure 1: employment-weighted CDF of m (simulated initial, G_ini, simulated end)
% Local names are used instead of reassigning G, which holds the SDE
% diffusion handle, so the simulation section can be re-run in the same
% session.
figure(1)
hold off

% Initial draws: sort by m and accumulate employment weights.
[m_sorted_ini, order_ini] = sort(M_ini);
w_sorted_ini = N_ini(order_ini);
plot(m_sorted_ini, cumsum(w_sorted_ini)./sum(w_sorted_ini))

hold on
plot(mm, G_ini)

% End states. Firms with NaN employment are dropped; otherwise a single NaN
% turns the cumulative sum, and hence the whole curve, into NaN.
valid_end = ~isnan(N_end);
N_end_valid = N_end(valid_end);
[m_sorted_end, order_end] = sort(M_end(valid_end));
w_sorted_end = N_end_valid(order_end);
plot(m_sorted_end, cumsum(w_sorted_end)./sum(w_sorted_end))

xlabel('m')
ylabel('employment-weighted CDF')
legend('simulated, initial', 'G\_ini', 'simulated, end', 'Location', 'southeast')

%% Firm-size distribution
% Employment bin edges. Bin k is (emp(k), emp(k+1)]. The last edge lies
% above the largest simulated firm, so every finite positive N_end is binned.
emp = [0, 5, 10, 15, 20, 25, 30, 35, 40, 50, 75, 100, 150, 200, 300, ...
    400, 500, 750, 1000, 1500, 2000, 2500, 5000, max(N_end)+1];

% Number of firms per bin (counts, despite the name). Bins are counted one
% at a time; broadcasting N_end against all edges at once would build
% n_sample-by-(numel(emp)-1) arrays, several GB for n_sample = 10^7.
n_bins = numel(emp) - 1;
firm_share = zeros(1, n_bins);
for k = 1:n_bins
    firm_share(k) = nnz(N_end > emp(k) & N_end <= emp(k+1));
end

% remaining(k) = fraction of firms with N <= emp(k+1). NaN employments fall
% in no bin, so they are also excluded from the denominator.
remaining = cumsum(firm_share)/nnz(~isnan(N_end));

% 1 - remaining(k) is the share of firms with N > emp(k+1), so it is
% plotted against log(emp(k+1)). The last bin, where the share is 0 and
% log(0) = -Inf, is dropped.
figure(2)
plot(log(emp(2:end-1)), log(1-remaining(1:end-1)), 'o')
xlabel('log employment')
ylabel('log share of firms larger')

%% Save simulated firm data and the size-distribution table
firm_data.M_end   = M_end;
firm_data.Age_end = Age_end;
firm_data.N_end   = N_end;
save(fullfile(path, 'firm_data.mat'), 'firm_data');

% writetable fails if the target folder does not exist, so create it first.
if ~exist(datapath, 'dir')
    mkdir(datapath);
end

% firm_share holds the number of firms in each bin (emp(k), emp(k+1)].
T = table(emp(1:end-1)', emp(2:end)', firm_share', ...
    'VariableNames', {'firm_size_low', 'firm_size_high', 'share'});

writetable(T, fullfile(datapath, filename), 'Sheet', 'Figure D', 'Range', 'A1')


%% Save raw data (optional, disabled by default)
% This section was previously guarded by `if 2==1`. That version could not
% have run:
%   - it referenced H_end, E_end, Q_end, L_end and time_end, none of which
%     this script defines;
%   - it wrote to the same sheet/range as the summary table, overwriting it;
%   - n_sample = 10^7 rows exceeds the 1,048,576-row limit of an .xlsx sheet.
% The raw per-firm data therefore goes to a separate CSV file.
save_raw_data = false;

if save_raw_data
    valid = ~isnan(N_end);
    raw = table(N_end(valid), Age_end(valid), reshape(M_end(valid), [], 1), ...
        'VariableNames', {'employment', 'age', 'marginal_product'});
    writetable(raw, fullfile(datapath, 'firm_data_raw.csv'));
end

